from collections import namedtuple

from itertools import product

from typing import Any, TypeVar

from axelrod.action import Action, actions_to_str, str_to_actions

from axelrod.evolvable_player import (
    EvolvablePlayer,
    InsufficientParametersError,
    crossover_dictionaries,
)

from axelrod.player import Player

C = Action.C
D = Action.D

# Both possible moves, with C listed before D.
actions = (C, D)

# Key type for lookup tables: the player's own plays, the opponent's plays
# and the opponent's opening plays, each stored as a tuple of actions.
Plays = namedtuple("Plays", ["self_plays", "op_plays", "op_openings"])

# Type of a table value: either an Action or a float.
Reaction = TypeVar("Reaction", Action, float)

def make_keys_into_plays(lookup_table: dict) -> dict:
    """Return a shallow copy of ``lookup_table`` whose keys are all ``Plays``.

    If every key is already a ``Plays`` the copy is returned unchanged.
    Otherwise every key is unpacked into ``Plays(*key)``, so each key must be
    a 3-item sequence ``(self_plays, op_plays, op_openings)``.
    """
    table = lookup_table.copy()
    if all(isinstance(key, Plays) for key in table):
        return table
    return {Plays(*key): value for key, value in table.items()}

def create_lookup_table_keys(
    player_depth: int, op_depth: int, op_openings_depth: int
) -> list:
    """Returns a list of Plays holding every combination of C's and D's for
    each specified depth.

    The list is ordered with C < D, sorted by
    ((player_tuple), (op_tuple), (op_openings_tuple)), which is the order
    produced by ``itertools.product``.
    create_lookup_table_keys(2, 1, 0) returns::

        [Plays(self_plays=(C, C), op_plays=(C,), op_openings=()),
         Plays(self_plays=(C, C), op_plays=(D,), op_openings=()),
         Plays(self_plays=(C, D), op_plays=(C,), op_openings=()),
         Plays(self_plays=(C, D), op_plays=(D,), op_openings=()),
         Plays(self_plays=(D, C), op_plays=(C,), op_openings=()),
         Plays(self_plays=(D, C), op_plays=(D,), op_openings=()),
         Plays(self_plays=(D, D), op_plays=(C,), op_openings=()),
         Plays(self_plays=(D, D), op_plays=(D,), op_openings=())]

    :param player_depth: length of the ``self_plays`` tuple in every key.
    :param op_depth: length of the ``op_plays`` tuple in every key.
    :param op_openings_depth: length of the ``op_openings`` tuple in every key.
    :returns: a list of ``2 ** (player_depth + op_depth + op_openings_depth)``
        Plays.
    :raises ValueError: if any depth is negative (raised by itertools.product).
    """
    choices = (C, D)
    self_plays = product(choices, repeat=player_depth)
    op_plays = product(choices, repeat=op_depth)
    op_openings = product(choices, repeat=op_openings_depth)

    combinations = product(self_plays, op_plays, op_openings)
    return [Plays(*combination) for combination in combinations]

# A depth-(0, 1, 0) table that ignores the player's own plays and the
# opponent's openings, and answers the opponent's last play with the same play.
default_tft_lookup_table = {
    Plays((), (D,), ()): D,
    Plays((), (C,), ()): C,
}

def get_last_n_plays(player: Player, depth: int) -> tuple:
    """Return the final ``depth`` entries of ``player.history`` as a tuple.

    A depth of 0 yields an empty tuple.
    """
    if depth == 0:
        # Special-cased because history[-0:] would be the *whole* history.
        return ()
    recent = player.history[-depth:]
    return tuple(recent)

class LookupTable(object):
    """
    LookerUp and its children use this object to determine their next actions.

    It is an object that creates a table of all possible plays to a specified
    depth and the action to be returned for each combination of plays.
    The "get" method returns the appropriate response.
    For the table containing::

        ....
        Plays(self_plays=(C, C), op_plays=(C, D), op_openings=(D, C)): D
        Plays(self_plays=(C, C), op_plays=(C, D), op_openings=(D, D)): C
        ...

    with:
    player.history[-2:]=[C, C] and
    opponent.history[-2:]=[C, D] and
    opponent.history[:2]=[D, D],
    calling LookupTable.get(plays=(C, C), op_plays=(C, D), op_openings=(D, D))
    will return C.

    Instantiate the table with a lookup_dict. This is
    {(self_plays_tuple, op_plays_tuple, op_openings_tuple): action, ...}.
    It must contain every possible permutation of C's and D's for the above
    tuple. So::

        good_dict = {((C,), (C,), ()): C,
                     ((C,), (D,), ()): C,
                     ((D,), (C,), ()): D,
                     ((D,), (D,), ()): C}

        bad_dict = {((C,), (C,), ()): C,
                    ((C,), (D,), ()): C,
                    ((D,), (C,), ()): D}

    LookupTable.from_pattern() creates an ordered list of keys for you and maps
    the pattern to the keys::

        LookupTable.from_pattern(pattern=(C, D, D, C),
            player_depth=0, op_depth=1, op_openings_depth=1
        )

    creates the dictionary::

        {Plays(self_plays=(), op_plays=(C,), op_openings=(C,)): C,
         Plays(self_plays=(), op_plays=(C,), op_openings=(D,)): D,
         Plays(self_plays=(), op_plays=(D,), op_openings=(C,)): D,
         Plays(self_plays=(), op_plays=(D,), op_openings=(D,)): C,}

    and then returns a LookupTable with that dictionary.
    """

    def __init__(self, lookup_dict: dict) -> None:
        """
        :param lookup_dict: maps Plays (or 3-tuples convertible to Plays) to
            the value returned for that combination of plays.
        :raises ValueError: if lookup_dict is empty, if its keys are not all
            the same size, or if it does not cover every combination of C's
            and D's.
        """
        self._dict = make_keys_into_plays(lookup_dict)
        if not self._dict:
            # Without this guard, next(iter(...)) below would leak a bare
            # StopIteration instead of the documented ValueError.
            raise ValueError("Lookup table must contain at least one key.")

        # All keys must share the same depths; _raise_error_for_bad_lookup_dict
        # verifies this, so any key can supply the reference sizes.
        sample_key = next(iter(self._dict))
        self._plays_depth = len(sample_key.self_plays)
        self._op_plays_depth = len(sample_key.op_plays)
        self._op_openings_depth = len(sample_key.op_openings)
        self._table_depth = max(
            self._plays_depth, self._op_plays_depth, self._op_openings_depth
        )
        self._raise_error_for_bad_lookup_dict()

    def _raise_error_for_bad_lookup_dict(self):
        """Raise ValueError unless the keys are uniform and exhaustive."""
        if any(
            len(key.self_plays) != self._plays_depth
            or len(key.op_plays) != self._op_plays_depth
            or len(key.op_openings) != self._op_openings_depth
            for key in self._dict
        ):
            raise ValueError("Lookup table keys are not all the same size.")
        # Each position in a key is C or D, so a complete table has exactly
        # 2 ** (total number of positions) distinct keys.
        total_key_combinations = 2 ** (
            self._plays_depth + self._op_plays_depth + self._op_openings_depth
        )
        if total_key_combinations != len(self._dict):
            msg = (
                "Lookup table does not have enough keys"
                + " to cover all possibilities."
            )
            raise ValueError(msg)

    @classmethod
    def from_pattern(
        cls,
        pattern: tuple,
        player_depth: int,
        op_depth: int,
        op_openings_depth: int,
    ):
        """Build a table by zipping ``pattern`` onto the ordered keys from
        create_lookup_table_keys.

        :raises ValueError: if len(pattern) differs from the number of keys.
        """
        keys = create_lookup_table_keys(
            player_depth=player_depth,
            op_depth=op_depth,
            op_openings_depth=op_openings_depth,
        )
        if len(keys) != len(pattern):
            msg = "Pattern must be len: {}, but was len: {}".format(
                len(keys), len(pattern)
            )
            raise ValueError(msg)
        input_dict = dict(zip(keys, pattern))
        return cls(input_dict)

    def get(self, plays: tuple, op_plays: tuple, op_openings: tuple) -> Any:
        """Return the value stored for this combination of plays.

        :raises KeyError: if the combination is not in the table.
        """
        return self._dict[
            Plays(self_plays=plays, op_plays=op_plays, op_openings=op_openings)
        ]

    @property
    def player_depth(self) -> int:
        """Length of ``self_plays`` in every key."""
        return self._plays_depth

    @property
    def op_depth(self) -> int:
        """Length of ``op_plays`` in every key."""
        return self._op_plays_depth

    @property
    def op_openings_depth(self) -> int:
        """Length of ``op_openings`` in every key."""
        return self._op_openings_depth

    @property
    def table_depth(self) -> int:
        """The largest of the three depths."""
        return self._table_depth

    @property
    def dictionary(self) -> dict:
        """A shallow copy of the underlying Plays -> value mapping."""
        return self._dict.copy()

    def display(
        self, sort_by: tuple = ("op_openings", "self_plays", "op_plays")
    ) -> str:
        """
        Returns a string for printing lookup_table info in specified order.

        :param sort_by: an ordering of the three field names 'self_plays',
            'op_plays' and 'op_openings'; it sets both the column order and
            the row sort order.
        """

        def sorter(plays):
            # One action string (e.g. "CD") per field, compared field by field.
            return tuple(
                actions_to_str(getattr(plays, field)) for field in sort_by
            )

        col_width = 11
        sorted_keys = sorted(self._dict, key=sorter)
        header_line = (
            "{str_list[0]:^{width}}|"
            + "{str_list[1]:^{width}}|"
            + "{str_list[2]:^{width}}"
        )
        display_line = header_line.replace("|", ",") + ": {str_list[3]},"

        def make_commaed_str(action_tuple):
            return ", ".join(str(action) for action in action_tuple)

        line_elements = [
            (
                make_commaed_str(getattr(key, sort_by[0])),
                make_commaed_str(getattr(key, sort_by[1])),
                make_commaed_str(getattr(key, sort_by[2])),
                self._dict[key],
            )
            for key in sorted_keys
        ]
        header = header_line.format(str_list=sort_by, width=col_width) + "\n"
        lines = [
            display_line.format(str_list=line, width=col_width)
            for line in line_elements
        ]
        return header + "\n".join(lines) + "\n"

    def __eq__(self, other) -> bool:
        """Tables are equal when they are both LookupTables with equal
        dictionaries; any other object compares unequal."""
        if not isinstance(other, LookupTable):
            return False
        return self._dict == other.dictionary